{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute main.py inside this kernel; the upper-case constants and helper\n",
    "# names used throughout the notebook are expected to come from it.\n",
    "%run main.py\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "# Create the working directories without shelling out. This matches\n",
    "# `mkdir -p` (parents created, existing directories tolerated) but also\n",
    "# works for paths containing spaces and on systems without `mkdir -p`.\n",
    "for directory in (DATA_DIR, BERT_DIR, MODEL_DIR):\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "# Client used by the download cells below.\n",
    "s3 = S3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch each corpus file only if it is missing locally. Every file gets its\n",
    "# own existence check: guarding both downloads with NE5 alone would leave\n",
    "# FACTRU missing whenever NE5 is already on disk.\n",
    "for remote, local in [(S3_NE5, NE5), (S3_FACTRU, FACTRU)]:\n",
    "    if not exists(local):\n",
    "        s3.download(remote, local)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the pretrained BERT artifacts that are missing locally. Each file is\n",
    "# checked on its own so that an interrupted or partial download (vocab\n",
    "# present, weights absent) is repaired instead of being skipped.\n",
    "bert_files = [\n",
    "    (S3_BERT_VOCAB, BERT_VOCAB),\n",
    "    (S3_BERT_EMB, BERT_EMB),\n",
    "    (S3_BERT_ENCODER, BERT_ENCODER),\n",
    "]\n",
    "for remote, local in bert_files:\n",
    "    if not exists(local):\n",
    "        s3.download(remote, local)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output side: BIO tag vocabulary built from the TAGS constant.\n",
    "tags_vocab = BIOTagsVocab(TAGS)\n",
    "\n",
    "# Input side: word vocabulary loaded from the BERT vocab file.\n",
    "words_vocab = BERTVocab.load(BERT_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch first, then the second generator, both with SEED.\n",
    "# NOTE(review): `seed` is presumably random.seed imported by main.py - confirm.\n",
    "for seed_rng in (torch.manual_seed, seed):\n",
    "    seed_rng(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = RuBERTConfig()\n",
    "\n",
    "# Build the three parts and wrap them in a single BERTNER model.\n",
    "emb = BERTEmbedding.from_config(config)\n",
    "encoder = BERTEncoder.from_config(config)\n",
    "tags_count = len(tags_vocab)\n",
    "ner = BERTNERHead(config.emb_dim, tags_count)\n",
    "model = BERTNER(emb, encoder, ner)\n",
    "\n",
    "# Turn off gradients for the embedding weights; the optimizer below is only\n",
    "# given the encoder and NER-head parameters.\n",
    "for weight in emb.parameters():\n",
    "    weight.requires_grad = False\n",
    "\n",
    "# Load the pretrained embedding and encoder weights, then move the whole\n",
    "# model to the target device.\n",
    "model.emb.load(BERT_EMB)\n",
    "model.encoder.load(BERT_ENCODER)\n",
    "model = model.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "000cbfe4879b482e80183e8396300ef0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Pick the corpus: the custom texts when CUSTOM_TUNING is set, NE5 otherwise.\n",
    "corpus_path = CUSTOM_TEXTS if CUSTOM_TUNING else NE5\n",
    "lines = load_gz_lines(corpus_path)\n",
    "items = log_progress(parse_jl(lines))\n",
    "\n",
    "# Convert every span-annotated item into a BIO-tagged record over its tokens.\n",
    "records = []\n",
    "for item in items:\n",
    "    markup = SpanMarkup.from_json(item)\n",
    "    tokens = list(tokenize(markup.text))\n",
    "    records.append(markup.to_bio(tokens))\n",
    "\n",
    "# Hold out the first 20% of the records (in file order, no shuffling) as the\n",
    "# test split; the remainder is the train split.\n",
    "test_share = 0.2\n",
    "size = round(len(records) * test_share)\n",
    "\n",
    "markups = {}\n",
    "markups[TEST] = records[:size]\n",
    "markups[TRAIN] = records[size:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder that turns BIO records into model-ready batches.\n",
    "encode_options = dict(seq_len=128, batch_size=32, shuffle_size=10000)\n",
    "encode = BERTNERTrainEncoder(words_vocab, tags_vocab, **encode_options)\n",
    "\n",
    "# Encode each split once up front (TEST first, then TRAIN) and keep every\n",
    "# batch on DEVICE, so all epochs iterate over the same materialized lists.\n",
    "batches = {\n",
    "    name: [batch.to(DEVICE) for batch in encode(markups[name])]\n",
    "    for name in [TEST, TRAIN]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fan metrics out to two sinks at once: a TensorBoard run and a log board.\n",
    "tensor_board = TensorBoard(BOARD_NAME, RUNS_DIR)\n",
    "log_board = LogBoard()\n",
    "board = MultiBoard([tensor_board, log_board])\n",
    "\n",
    "# One named section of the combined board per data split.\n",
    "train_section = board.section(TRAIN_BOARD)\n",
    "test_section = board.section(TEST_BOARD)\n",
    "boards = {TRAIN: train_section, TEST: test_section}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two parameter groups with separate learning rates: BERT_LR for the\n",
    "# encoder, LR for the NER head. The embedding is not optimized at all.\n",
    "param_groups = [\n",
    "    {'params': encoder.parameters(), 'lr': BERT_LR},\n",
    "    {'params': ner.parameters(), 'lr': LR},\n",
    "]\n",
    "optimizer = optim.Adam(param_groups)\n",
    "\n",
    "# Multiply both learning rates by LR_GAMMA on every scheduler step.\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One score accumulator per split; each is written to its board section and\n",
    "# reset at the end of the corresponding pass.\n",
    "meters = {split: NERScoreMeter() for split in (TRAIN, TEST)}\n",
    "\n",
    "for epoch in log_progress(range(EPOCHS)):\n",
    "    # Training pass: one optimizer step per batch, only the loss is tracked.\n",
    "    model.train()\n",
    "    train_meter = meters[TRAIN]\n",
    "    for batch in log_progress(batches[TRAIN], leave=False):\n",
    "        optimizer.zero_grad()\n",
    "        processed = process_batch(model, ner.crf, batch)\n",
    "        processed.loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_meter.add(NERBatchScore(processed.loss))\n",
    "\n",
    "    train_meter.write(boards[TRAIN])\n",
    "    train_meter.reset()\n",
    "\n",
    "    # Evaluation pass: gradients disabled; targets are split by mask and\n",
    "    # predictions are decoded by the CRF before scoring against tags_vocab.\n",
    "    model.eval()\n",
    "    test_meter = meters[TEST]\n",
    "    with torch.no_grad():\n",
    "        for batch in log_progress(batches[TEST], leave=False, desc=TEST):\n",
    "            processed = process_batch(model, ner.crf, batch)\n",
    "            target, pred = processed.target, processed.pred\n",
    "            processed.target = split_masked(target.value, target.mask)\n",
    "            processed.pred = ner.crf.decode(pred.value, pred.mask)\n",
    "            test_meter.add(score_ner_batch(processed, tags_vocab))\n",
    "\n",
    "        test_meter.write(boards[TEST])\n",
    "        test_meter.reset()\n",
    "\n",
    "    # Once per epoch: step the LR schedule, then advance the board step.\n",
    "    scheduler.step()\n",
    "    board.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference log, apparently from an earlier run of the training loop above.\n",
    "# Columns: [timestamp], epoch, value, metric name.\n",
    "#\n",
    "# [2020-03-31 14:05:40]    0 14.3334 01_train/01_loss\n",
    "# [2020-03-31 14:05:43]    0 2.3965 02_test/01_loss\n",
    "# [2020-03-31 14:05:43]    0 0.9962 02_test/02_PER\n",
    "# [2020-03-31 14:05:43]    0 0.9807 02_test/03_LOC\n",
    "# [2020-03-31 14:05:43]    0 0.9691 02_test/04_ORG\n",
    "# [2020-03-31 14:06:10]    1 1.8448 01_train/01_loss\n",
    "# [2020-03-31 14:06:13]    1 2.1326 02_test/01_loss\n",
    "# [2020-03-31 14:06:13]    1 0.9975 02_test/02_PER\n",
    "# [2020-03-31 14:06:13]    1 0.9862 02_test/03_LOC\n",
    "# [2020-03-31 14:06:13]    1 0.9710 02_test/04_ORG\n",
    "# [2020-03-31 14:06:40]    2 1.2753 01_train/01_loss\n",
    "# [2020-03-31 14:06:43]    2 2.1436 02_test/01_loss\n",
    "# [2020-03-31 14:06:43]    2 0.9972 02_test/02_PER\n",
    "# [2020-03-31 14:06:43]    2 0.9867 02_test/03_LOC\n",
    "# [2020-03-31 14:06:43]    2 0.9705 02_test/04_ORG\n",
    "# [2020-03-31 14:07:10]    3 1.1283 01_train/01_loss\n",
    "# [2020-03-31 14:07:13]    3 2.1885 02_test/01_loss\n",
    "# [2020-03-31 14:07:13]    3 0.9975 02_test/02_PER\n",
    "# [2020-03-31 14:07:13]    3 0.9867 02_test/03_LOC\n",
    "# [2020-03-31 14:07:13]    3 0.9719 02_test/04_ORG\n",
    "# [2020-03-31 14:07:40]    4 1.0464 01_train/01_loss\n",
    "# [2020-03-31 14:07:43]    4 2.1705 02_test/01_loss\n",
    "# [2020-03-31 14:07:43]    4 0.9977 02_test/02_PER\n",
    "# [2020-03-31 14:07:43]    4 0.9862 02_test/03_LOC\n",
    "# [2020-03-31 14:07:43]    4 0.9722 02_test/04_ORG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the trained encoder and NER head to their local model paths.\n",
    "for module, path in [(model.encoder, MODEL_ENCODER), (ner, MODEL_NER)]:\n",
    "    module.dump(path)\n",
    "\n",
    "# Upload to S3 is left disabled; uncomment to publish the dumped files.\n",
    "# s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
    "# s3.upload(MODEL_NER, S3_MODEL_NER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "fbd3b3cf5ce5dbcb71d33f7b8a90c542bb07cc48175c202e830100849640f809"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
